{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5a67ca73-04e8-46c2-bb83-5c02dbe15abb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looking in indexes: http://mirrors.aliyun.com/pypi/simple\n",
      "Requirement already satisfied: pandas in /root/miniconda3/lib/python3.8/site-packages (2.0.3)\n",
      "Requirement already satisfied: tzdata>=2022.1 in /root/miniconda3/lib/python3.8/site-packages (from pandas) (2024.1)\n",
      "Requirement already satisfied: python-dateutil>=2.8.2 in /root/miniconda3/lib/python3.8/site-packages (from pandas) (2.8.2)\n",
      "Requirement already satisfied: pytz>=2020.1 in /root/miniconda3/lib/python3.8/site-packages (from pandas) (2021.1)\n",
      "Requirement already satisfied: numpy>=1.20.3 in /root/miniconda3/lib/python3.8/site-packages (from pandas) (1.21.2)\n",
      "Requirement already satisfied: six>=1.5 in /root/miniconda3/lib/python3.8/site-packages (from python-dateutil>=2.8.2->pandas) (1.16.0)\n",
      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "! pip install pandas\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import pandas as pd\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import torch.nn.functional as F\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "19ae53ec-aa1f-40bc-bd0f-f39aa868019c",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Define the LSTM model\n",
    "class LSTMModel(nn.Module):\n",
    "    \"\"\"LSTM sequence classifier: stacked LSTM layers followed by a linear layer.\n",
    "\n",
    "    Args:\n",
    "        input_size: Number of features per time step.\n",
    "        hidden_size: Number of hidden units in each LSTM layer.\n",
    "        num_layers: Number of stacked LSTM layers.\n",
    "        num_classes: Size of the output logit vector.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_size, hidden_size, num_layers, num_classes):\n",
    "        super().__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "        self.num_layers = num_layers\n",
    "        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)\n",
    "        self.fc = nn.Linear(hidden_size, num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return class logits of shape (batch, num_classes).\n",
    "\n",
    "        Args:\n",
    "            x: Tensor of shape (batch, seq_len, input_size).\n",
    "        \"\"\"\n",
    "        # Build the initial states on the input's own device/dtype instead of\n",
    "        # reading a notebook-level `device` global, which is only defined in a\n",
    "        # later cell and may not match where `x` actually lives.\n",
    "        state_shape = (self.num_layers, x.size(0), self.hidden_size)\n",
    "        h0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)\n",
    "        c0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)\n",
    "        out, _ = self.lstm(x, (h0, c0))\n",
    "        # Classify from the LSTM output at the last time step only.\n",
    "        out = self.fc(out[:, -1, :])\n",
    "        return out\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6b2b7f8c-3a21-4db7-a521-abd94a361555",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Custom dataset class\n",
    "class CustomDataset(Dataset):\n",
    "    \"\"\"Dataset backed by a CSV file of (sequence, label) rows.\n",
    "\n",
    "    Column 0 is parsed as a comma-separated string of integers\n",
    "    (e.g. \"3, 1, 4\") and column 1 as the integer class label.\n",
    "\n",
    "    Args:\n",
    "        csv_file: Path of the CSV file, read once with ``pd.read_csv``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, csv_file):\n",
    "        self.data = pd.read_csv(csv_file)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return ``(sequence, label)`` as integer tensors for row ``idx``.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If a non-empty token of the sequence is not an integer.\n",
    "        \"\"\"\n",
    "        raw_sequence = str(self.data.iloc[idx, 0])\n",
    "        # Split on the bare comma: int() ignores surrounding whitespace, so\n",
    "        # \"1,2\", \"1, 2\" and \"1 ,2\" all parse, whereas splitting on \", \" raised\n",
    "        # ValueError for any row written without the space. Empty tokens\n",
    "        # (e.g. from a trailing comma) are skipped.\n",
    "        sequence = [int(token) for token in raw_sequence.split(',') if token.strip()]\n",
    "        label = int(self.data.iloc[idx, 1])\n",
    "        return torch.tensor(sequence), torch.tensor(label)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "03c95726-9acf-4a98-9555-c2e331e2df30",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Hyperparameters\n",
    "input_size = 1         # features per time step (each sequence element is one number)\n",
    "hidden_size = 128      # hidden units per LSTM layer\n",
    "num_layers = 2         # number of stacked LSTM layers\n",
    "num_classes = 3        # number of target classes\n",
    "batch_size = 4\n",
    "num_epochs = 20\n",
    "learning_rate = 1e-4\n",
    "\n",
    "# Use the GPU when one is available, otherwise fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda')\n",
    "else:\n",
    "    device = torch.device('cpu')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "999f1cd1-e1d3-4bf1-936e-7f5772caef48",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def collate_fn(batch):\n",
    "    \"\"\"Zero-pad the sequences of a batch to a common length and stack them.\n",
    "\n",
    "    Args:\n",
    "        batch: Iterable of ``(sequence, label)`` pairs, where each sequence is a\n",
    "            1-D tensor.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(padded, labels)``: ``padded`` stacks the sequences along dim 0\n",
    "        after padding each to the longest length in the batch; ``labels`` is a\n",
    "        tensor built from the batch labels.\n",
    "    \"\"\"\n",
    "    seqs, targets = zip(*batch)\n",
    "    longest = max(len(s) for s in seqs)\n",
    "    # NOTE(review): padding is appended after the real values, while the model\n",
    "    # in this notebook classifies from the last time step, so shorter sequences\n",
    "    # are classified after the LSTM has also consumed the padding - confirm\n",
    "    # this is intended.\n",
    "    padded = []\n",
    "    for s in seqs:\n",
    "        filler = torch.zeros(longest - len(s)).long()\n",
    "        padded.append(torch.cat([s, filler]))\n",
    "    return torch.stack(padded, dim=0), torch.tensor(targets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "089a9ce6-2f6a-4439-9985-0d97b02dc655",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Load the training set from its CSV file.\n",
    "dataset = CustomDataset('/root/LSTM2/train.csv')\n",
    "\n",
    "# Shuffle every epoch; collate_fn zero-pads the variable-length sequences so\n",
    "# they can be stacked into a single batch tensor.\n",
    "loader = DataLoader(\n",
    "    dataset,\n",
    "    batch_size=batch_size,\n",
    "    shuffle=True,\n",
    "    collate_fn=collate_fn,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "574c04dd-945c-4af9-b031-7f6a098d2b2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Build the model on the selected device, plus its loss function and optimizer.\n",
    "model = LSTMModel(\n",
    "    input_size=input_size,\n",
    "    hidden_size=hidden_size,\n",
    "    num_layers=num_layers,\n",
    "    num_classes=num_classes,\n",
    ")\n",
    "model = model.to(device)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "26249f47-2762-4eba-93c8-4e79c1e1ad91",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/20], Step [20/99], Loss: 1.0757\n",
      "Epoch [1/20], Step [40/99], Loss: 1.0249\n",
      "Epoch [1/20], Step [60/99], Loss: 0.9297\n",
      "Epoch [1/20], Step [80/99], Loss: 0.7434\n",
      "Epoch [2/20], Step [20/99], Loss: 1.0474\n",
      "Epoch [2/20], Step [40/99], Loss: 1.0273\n",
      "Epoch [2/20], Step [60/99], Loss: 1.3939\n",
      "Epoch [2/20], Step [80/99], Loss: 1.1131\n",
      "Epoch [3/20], Step [20/99], Loss: 0.6065\n",
      "Epoch [3/20], Step [40/99], Loss: 0.9430\n",
      "Epoch [3/20], Step [60/99], Loss: 1.7961\n",
      "Epoch [3/20], Step [80/99], Loss: 0.2311\n",
      "Epoch [4/20], Step [20/99], Loss: 0.7275\n",
      "Epoch [4/20], Step [40/99], Loss: 0.3173\n",
      "Epoch [4/20], Step [60/99], Loss: 0.8776\n",
      "Epoch [4/20], Step [80/99], Loss: 1.1190\n",
      "Epoch [5/20], Step [20/99], Loss: 0.9050\n",
      "Epoch [5/20], Step [40/99], Loss: 0.4028\n",
      "Epoch [5/20], Step [60/99], Loss: 0.4220\n",
      "Epoch [5/20], Step [80/99], Loss: 0.6530\n",
      "Epoch [6/20], Step [20/99], Loss: 0.7924\n",
      "Epoch [6/20], Step [40/99], Loss: 0.2966\n",
      "Epoch [6/20], Step [60/99], Loss: 0.2461\n",
      "Epoch [6/20], Step [80/99], Loss: 0.9683\n",
      "Epoch [7/20], Step [20/99], Loss: 0.8509\n",
      "Epoch [7/20], Step [40/99], Loss: 0.4120\n",
      "Epoch [7/20], Step [60/99], Loss: 0.9035\n",
      "Epoch [7/20], Step [80/99], Loss: 0.6890\n",
      "Epoch [8/20], Step [20/99], Loss: 0.2506\n",
      "Epoch [8/20], Step [40/99], Loss: 0.4719\n",
      "Epoch [8/20], Step [60/99], Loss: 0.6054\n",
      "Epoch [8/20], Step [80/99], Loss: 1.4887\n",
      "Epoch [9/20], Step [20/99], Loss: 0.5834\n",
      "Epoch [9/20], Step [40/99], Loss: 0.8767\n",
      "Epoch [9/20], Step [60/99], Loss: 0.3573\n",
      "Epoch [9/20], Step [80/99], Loss: 0.5523\n",
      "Epoch [10/20], Step [20/99], Loss: 0.7264\n",
      "Epoch [10/20], Step [40/99], Loss: 0.4253\n",
      "Epoch [10/20], Step [60/99], Loss: 0.8944\n",
      "Epoch [10/20], Step [80/99], Loss: 0.7511\n",
      "Epoch [11/20], Step [20/99], Loss: 0.2898\n",
      "Epoch [11/20], Step [40/99], Loss: 0.2340\n",
      "Epoch [11/20], Step [60/99], Loss: 1.3002\n",
      "Epoch [11/20], Step [80/99], Loss: 0.4056\n",
      "Epoch [12/20], Step [20/99], Loss: 0.9600\n",
      "Epoch [12/20], Step [40/99], Loss: 0.9047\n",
      "Epoch [12/20], Step [60/99], Loss: 0.2843\n",
      "Epoch [12/20], Step [80/99], Loss: 0.5169\n",
      "Epoch [13/20], Step [20/99], Loss: 0.8492\n",
      "Epoch [13/20], Step [40/99], Loss: 0.5095\n",
      "Epoch [13/20], Step [60/99], Loss: 0.6430\n",
      "Epoch [13/20], Step [80/99], Loss: 1.2651\n",
      "Epoch [14/20], Step [20/99], Loss: 0.9297\n",
      "Epoch [14/20], Step [40/99], Loss: 0.7597\n",
      "Epoch [14/20], Step [60/99], Loss: 0.4493\n",
      "Epoch [14/20], Step [80/99], Loss: 0.9000\n",
      "Epoch [15/20], Step [20/99], Loss: 0.2131\n",
      "Epoch [15/20], Step [40/99], Loss: 0.8448\n",
      "Epoch [15/20], Step [60/99], Loss: 0.6641\n",
      "Epoch [15/20], Step [80/99], Loss: 1.3297\n",
      "Epoch [16/20], Step [20/99], Loss: 0.6168\n",
      "Epoch [16/20], Step [40/99], Loss: 0.4628\n",
      "Epoch [16/20], Step [60/99], Loss: 0.7814\n",
      "Epoch [16/20], Step [80/99], Loss: 0.5469\n",
      "Epoch [17/20], Step [20/99], Loss: 0.1502\n",
      "Epoch [17/20], Step [40/99], Loss: 1.7633\n",
      "Epoch [17/20], Step [60/99], Loss: 1.3834\n",
      "Epoch [17/20], Step [80/99], Loss: 0.4703\n",
      "Epoch [18/20], Step [20/99], Loss: 0.5782\n",
      "Epoch [18/20], Step [40/99], Loss: 1.1975\n",
      "Epoch [18/20], Step [60/99], Loss: 0.7512\n",
      "Epoch [18/20], Step [80/99], Loss: 0.9541\n",
      "Epoch [19/20], Step [20/99], Loss: 0.5763\n",
      "Epoch [19/20], Step [40/99], Loss: 0.2392\n",
      "Epoch [19/20], Step [60/99], Loss: 0.7182\n",
      "Epoch [19/20], Step [80/99], Loss: 1.0052\n",
      "Epoch [20/20], Step [20/99], Loss: 0.7503\n",
      "Epoch [20/20], Step [40/99], Loss: 0.5545\n",
      "Epoch [20/20], Step [60/99], Loss: 0.4499\n",
      "Epoch [20/20], Step [80/99], Loss: 0.8240\n",
      "训练完成\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Train the model\n",
    "# Force training mode on every run: the evaluation cells call model.eval(),\n",
    "# and re-running this cell afterwards would otherwise train in eval mode\n",
    "# (cuDNN refuses LSTM backward passes outside training mode).\n",
    "model.train()\n",
    "total_step = len(loader)\n",
    "for epoch in range(num_epochs):\n",
    "    for i, (sequences, labels) in enumerate(loader):\n",
    "        # Reshape the padded batch to (batch_size, sequence_length, input_size).\n",
    "        sequences = sequences.view(-1, len(sequences[0]), input_size).float().to(device)\n",
    "        labels = labels.to(device)\n",
    "\n",
    "        # Forward pass\n",
    "        outputs = model(sequences)\n",
    "        loss = criterion(outputs, labels)\n",
    "\n",
    "        # Backward pass and parameter update\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Report the loss of the current batch every 20 steps.\n",
    "        if (i+1) % 20 == 0:\n",
    "            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, total_step, loss.item()))\n",
    "\n",
    "print('训练完成')\n",
    "\n",
    "# Evaluation on the test set follows in the next cells."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "4c2bf0a4-509e-42d5-b778-58ac71b35020",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the test set; keep file order (no shuffling) and reuse the padding collate_fn.\n",
    "test_dataset = CustomDataset('/root/LSTM2/test.csv')\n",
    "test_loader = DataLoader(\n",
    "    test_dataset,\n",
    "    batch_size=batch_size,\n",
    "    shuffle=False,\n",
    "    collate_fn=collate_fn,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "fe2dec3d-c556-4473-bcbb-173cbc203941",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "在测试集上的准确率: 83.81%\n"
     ]
    }
   ],
   "source": [
    "# Switch the model to evaluation mode before measuring accuracy.\n",
    "model.eval()\n",
    "\n",
    "total_correct = 0\n",
    "total_samples = 0\n",
    "with torch.no_grad():  # gradients are not needed for inference\n",
    "    for sequences, labels in test_loader:\n",
    "        # Reshape the padded batch to (batch_size, sequence_length, input_size).\n",
    "        sequences = sequences.view(-1, sequences.size(1), input_size).float().to(device)\n",
    "        labels = labels.to(device)\n",
    "\n",
    "        # Forward pass; the predicted class is the index of the largest logit.\n",
    "        outputs = model(sequences)\n",
    "        _, predicted = outputs.max(dim=1)\n",
    "\n",
    "        # Tally the correct predictions and the samples seen so far.\n",
    "        total_correct += (predicted == labels).sum().item()\n",
    "        total_samples += labels.shape[0]\n",
    "\n",
    "# Overall accuracy on the test set\n",
    "accuracy = total_correct / total_samples\n",
    "print('在测试集上的准确率: {:.2f}%'.format(accuracy * 100))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "28377cf5-1cd5-4411-b78b-f0402787b9d6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "预测的类别索引为: 0\n"
     ]
    }
   ],
   "source": [
    "# Example input: a single sequence given as a plain Python list.\n",
    "sequence_to_predict = [1]\n",
    "\n",
    "# Convert to a float tensor of shape (1, sequence_length, 1): one batch entry,\n",
    "# one feature per time step.\n",
    "sequence_tensor = torch.tensor(sequence_to_predict).unsqueeze(0).unsqueeze(-1).float().to(device)\n",
    "\n",
    "# Run inference in evaluation mode without tracking gradients.\n",
    "with torch.no_grad():\n",
    "    model.eval()\n",
    "    output = model(sequence_tensor)\n",
    "    # Take the arg-max of *this* prediction. The previous code read `outputs`,\n",
    "    # a stale variable left over from the test-set evaluation loop, so the\n",
    "    # printed class did not depend on sequence_to_predict at all.\n",
    "    _, predicted_class = torch.max(output, 1)\n",
    "\n",
    "# predicted_class holds the predicted class index; map it to a label name as needed.\n",
    "print(\"预测的类别索引为:\", predicted_class.item())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07a3cfb2-caaa-4b63-b363-a32fb8c587ab",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a62af889-090f-433d-87ad-ab651fd692db",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
